import os
import numpy as np
from itertools import product
import uuid

from math import ceil
import tqdm
from scipy.io import loadmat
import paddle
import paddle.nn.functional as F
try:
    import paddle._legacy_C_ops as C_ops
except:
    import paddle._C_ops as C_ops

def sigmoid(x):
    """Elementwise logistic function 1 / (1 + exp(-x)) on numpy inputs."""
    denom = 1 + np.exp(-x)
    return 1 / denom

def softmax(x):
    """Row-wise softmax of a 2-D array.

    Args:
        x: (N, C) ndarray of logits.

    Returns:
        (N, C) ndarray where each row is non-negative and sums to 1.
    """
    # subtract each row's max before exponentiating for numerical stability
    # (locals renamed: the original shadowed the builtins `max` and `sum`)
    row_max = np.max(x, axis=1, keepdims=True)
    exps = np.exp(x - row_max)
    row_sum = np.sum(exps, axis=1, keepdims=True)
    return exps / row_sum

def nms_bboxes(boxes, scores, iou_thresh=0.4):
    """Greedy non-maximum suppression.

    Args:
        boxes: (N, 4) ndarray of corner boxes (x1, y1, x2, y2).
        scores: (N,) ndarray of confidences.
        iou_thresh: IoU above which a lower-scored box is suppressed
            (previously hard-coded to 0.4; default preserves that behavior).

    Returns:
        ndarray of indices into `boxes` that survive, in descending score order.
    """
    x = boxes[:, 0]
    y = boxes[:, 1]
    w = boxes[:, 2] - boxes[:, 0]
    h = boxes[:, 3] - boxes[:, 1]

    areas = w * h
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)

        # overlap of the current best box with every remaining candidate
        xx1 = np.maximum(x[i], x[order[1:]])
        yy1 = np.maximum(y[i], y[order[1:]])
        xx2 = np.minimum(x[i] + w[i], x[order[1:]] + w[order[1:]])
        yy2 = np.minimum(y[i] + h[i], y[order[1:]] + h[order[1:]])

        # NOTE(review): the 1e-6 epsilon is added to the extent *before*
        # clamping, slightly inflating the intersection; kept as-is to
        # preserve the original numerics.
        w1 = np.maximum(.0, xx2 - xx1 + 0.000001)
        h1 = np.maximum(.0, yy2 - yy1 + 0.000001)
        inter = w1 * h1

        iou = inter / (areas[i] + areas[order[1:]] - inter)
        # keep only candidates whose overlap with i is at or below threshold
        inds = np.where(iou <= iou_thresh)[0]
        order = order[inds + 1]

    keep = np.array(keep)
    return keep

def filter_bbox(bboxes, scores, score_thresh=0.02):
    """Drop detections scoring below `score_thresh`.

    Args:
        bboxes: (N, 4) ndarray of boxes.
        scores: (N,) ndarray of confidences.
        score_thresh: minimum score to keep (previously hard-coded 0.02;
            default preserves that behavior).

    Returns:
        (bboxes, scores) restricted to rows with score >= score_thresh.
    """
    pos = np.where(scores >= score_thresh)
    bboxes = bboxes[pos]
    scores = scores[pos]

    return bboxes, scores

def generate_prior_bbox(image_size, min_sizes_list=[[16, 32], [64, 128], [256, 512]], steps=[8,16,32]):
    """Build RetinaFace prior boxes in normalized (cx, cy, w, h) form.

    Args:
        image_size: (height, width) of the network input.
        min_sizes_list: anchor pixel sizes for each feature level.
        steps: stride of each feature level.

    Returns:
        (num_priors, 4) ndarray, values clipped to [0, 1].
    """
    prior_rows = []
    for level, step in enumerate(steps):
        fm_h = ceil(image_size[0] / step)
        fm_w = ceil(image_size[1] / step)
        for row, col in product(range(fm_h), range(fm_w)):
            # anchor center at the middle of the cell, normalized by image size
            cx = (col + 0.5) * step / image_size[1]
            cy = (row + 0.5) * step / image_size[0]
            for min_size in min_sizes_list[level]:
                prior_rows.append(
                    [cx, cy, min_size / image_size[1], min_size / image_size[0]])

    priors = np.array(prior_rows).reshape(-1, 4)
    return priors.clip(max=1, min=0)

def decode(loc, priors, variances=(0.1, 0.2)):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.

    Args:
        loc (ndarray): location predictions for loc layers,
            Shape: [num_priors, 4]
        priors (ndarray): Prior boxes in center-offset form (cx, cy, w, h).
            Shape: [num_priors, 4].
        variances: two scale factors for the center and size offsets.
            (Default changed from a mutable list to an equivalent tuple.)

    Returns:
        (num_priors, 4) ndarray of (x1, y1, x2, y2) boxes clipped to [0, 1].
    """
    boxes_xy = priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:]
    boxes_wh = priors[:, 2:] * np.exp(loc[:, 2:] * variances[1])
    boxes = np.concatenate((boxes_xy, boxes_wh), axis = -1)
    # convert (cx, cy, w, h) -> (x1, y1, x2, y2)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    boxes = boxes.clip(max=1, min=0)
    return boxes

def get_retinaface_post(outputs, batch_size, params):
    """Post-process raw RetinaFace outputs into per-image detections.

    Args:
        outputs: list of three flat output arrays from the network.
        batch_size: number of images in the batch.
        params: dict with "input_size" ("c,h,w") and "outputs_size"
            ("#"-separated "rows,cols" shapes per output).

    Returns:
        list with one (num_det, 5) ndarray of [x1, y1, x2, y2, score]
        per image (an empty array when nothing survives filtering).
    """
    _, height, width = params["input_size"].split(",")
    # reorder so index 0 holds box regressions and index 2 holds the
    # scores the loop below reads
    outputs = [outputs[0], outputs[2], outputs[1]]

    shape_chunks = params['outputs_size'].split("#")
    shapes = [[int(dim) for dim in chunk.split(",")] for chunk in shape_chunks]
    outputs = [out.reshape(-1, shape[0], shape[1])
               for out, shape in zip(outputs, shapes)]

    priors = generate_prior_bbox(image_size=(int(height), int(width)))

    npreds = []
    for idx in range(batch_size):
        boxes = decode(outputs[0][idx], priors)
        # column 1 of the score tensor is the face confidence used here
        confs = outputs[2][idx][..., 1]

        boxes, confs = filter_bbox(boxes, confs)
        keep = nms_bboxes(boxes, confs)
        if len(keep) != 0:
            boxes = boxes[keep]
            confs = confs[keep]

        if len(confs) != 0:
            detections = np.concatenate(
                (boxes, np.expand_dims(confs, axis=-1)), axis=-1)
            npreds.append(detections)
        else:
            npreds.append(np.array([]))

    return npreds

def _to_list(l):
    if isinstance(l, (list, tuple)):
        return list(l)
    return [l]

class AnchorGeneratorSSD(object):
    """Generate SSD prior (default) boxes via paddle's prior_box op.

    Explicit per-level min/max sizes may be supplied; when both lists are
    empty they are derived from min_ratio/max_ratio with the standard SSD
    size schedule.
    """

    def __init__(self,
                 steps=[8, 16, 32, 64, 100, 300],
                 aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
                 min_ratio=15,
                 max_ratio=90,
                 base_size=300,
                 min_sizes=[30.0, 60.0, 111.0, 162.0, 213.0, 264.0],
                 max_sizes=[60.0, 111.0, 162.0, 213.0, 264.0, 315.0],
                 offset=0.5,
                 flip=True,
                 clip=False,
                 min_max_aspect_ratios_order=False):
        self.steps = steps
        self.aspect_ratios = aspect_ratios
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
        self.base_size = base_size
        # copy so the ratio-derived branch below never mutates the caller's
        # lists (or the shared mutable default arguments)
        self.min_sizes = list(min_sizes)
        self.max_sizes = list(max_sizes)
        self.offset = offset
        self.flip = flip
        self.clip = clip
        self.min_max_aspect_ratios_order = min_max_aspect_ratios_order

        if self.min_sizes == [] and self.max_sizes == []:
            num_layer = len(aspect_ratios)
            # Derive sizes from the ratio range.
            # Fix: the original used math.floor and six.moves.range, but
            # neither `math` nor `six` is imported anywhere in this module,
            # so this branch always raised NameError. Integer floor
            # division and builtin range are equivalent here.
            step = int((self.max_ratio - self.min_ratio) // (num_layer - 2))
            for ratio in range(self.min_ratio, self.max_ratio + 1, step):
                self.min_sizes.append(self.base_size * ratio / 100.)
                self.max_sizes.append(self.base_size * (ratio + step) / 100.)
            self.min_sizes = [self.base_size * .10] + self.min_sizes
            self.max_sizes = [self.base_size * .20] + self.max_sizes

        # number of priors per spatial location for each feature level
        self.num_priors = []
        for aspect_ratio, min_size, max_size in zip(
                aspect_ratios, self.min_sizes, self.max_sizes):
            # scalars count as one size (same result as len(_to_list(...)))
            n_min = len(min_size) if isinstance(min_size, (list, tuple)) else 1
            n_max = len(max_size) if isinstance(max_size, (list, tuple)) else 1
            if isinstance(min_size, (list, tuple)):
                self.num_priors.append(n_min + n_max)
            else:
                self.num_priors.append(
                    (len(aspect_ratio) * 2 + 1) * n_min + n_max)

    def __call__(self, inputs, image):
        """Return a list of (num_priors, 4) prior-box tensors, one per level."""
        boxes = []
        for input, min_size, max_size, aspect_ratio, step in zip(
                inputs, self.min_sizes, self.max_sizes, self.aspect_ratios,
                self.steps):
            box, _ = self.prior_box(
                input=input,
                image=image,
                min_sizes=_to_list(min_size),
                max_sizes=_to_list(max_size),
                aspect_ratios=aspect_ratio,
                steps=[step, step])
            boxes.append(paddle.reshape(box, [-1, 4]))
        return boxes

    def prior_box(self, input, image, min_sizes, max_sizes=None, aspect_ratios=[1.], variance=[0.1, 0.1, 0.2, 0.2], steps=[0.0, 0.0]):
        """Thin wrapper over the paddle C prior_box op.

        Returns the (boxes, variances) tensor pair produced by the op.
        """
        attrs = ('min_sizes', min_sizes, 'aspect_ratios', aspect_ratios,
                 'variances', variance, 'flip', self.flip, 'clip', self.clip, 'step_w',
                 steps[0], 'step_h', steps[1], 'offset', self.offset,
                 'min_max_aspect_ratios_order', self.min_max_aspect_ratios_order)
        if max_sizes is not None and len(max_sizes) > 0:
            attrs += ('max_sizes', max_sizes)
        box, var = C_ops.prior_box(input, image, *attrs)
        return box, var

def pd_multiclass_nms(bboxes, scores, score_thr, nms_thr, keep_top_k, cls_num=80, nms_top_k=1000):
    """Run paddle's multiclass_nms3 C op over boxes and per-class scores.

    NOTE(review): cls_num is forwarded as "background_label" — confirm that
    callers really intend the background label to equal the class count.
    """
    attr_pairs = [
        ("background_label", cls_num),
        ("score_threshold", score_thr),
        ("nms_top_k", nms_top_k),
        ("nms_threshold", nms_thr),
        ("keep_top_k", keep_top_k),
        ("nms_eta", 1.0),
        ("normalized", True),
    ]
    flat_attrs = [item for pair in attr_pairs for item in pair]

    output, _, _ = C_ops.multiclass_nms3(bboxes, scores, None, *flat_attrs)

    return output

def get_ssd_pd_bboxes(outputs, batch_size, params):
    """Decode SSD head outputs into final [x1, y1, x2, y2, score] detections.

    Args:
        outputs: flat list of per-level arrays; the first half are box
            regressions, the second half class scores (split at len//2 below).
        batch_size: number of prediction entries to produce.
        params: dict of string-encoded settings ("input_size", "feat_size",
            "aspect_ratios", "steps", "min_ratio", "min_sizes", "max_sizes",
            optional flag key "min_max_aspect_ratios_order").

    Returns:
        list of (num_det, 5) ndarrays, one per batch element.
    """
    channel, height, width = params["input_size"].split(",")
    feat_size = [int(val) for val in params["feat_size"].split(",")]
    aspect_ratios = [[float(ratio) for ratio in ratios.split(",") ] for ratios in params["aspect_ratios"].split("#")]
    steps = [float(val) for val in params["steps"].split(",")]
    min_ratio = int(params["min_ratio"])
    # a "-1" entry acts as an empty placeholder in the size lists
    if "#" in params["min_sizes"]:
        min_sizes = [[ [] if size=="-1" else float(size) for size in sizes.split(",") ] for sizes in params["min_sizes"].split("#")]
        max_sizes = [[] for sizes in params["max_sizes"].split("#")]
    else:
        min_sizes = [ [] if size=="-1" else float(size) for size in params["min_sizes"].split(",") ] 
        max_sizes = [ [] if size=="-1" else float(size) for size in params["max_sizes"].split(",") ]
    min_max_aspect_ratios_order = True if "min_max_aspect_ratios_order" in params.keys() else False
    ssd_bbox =  AnchorGeneratorSSD(aspect_ratios=aspect_ratios,steps=steps,min_ratio=min_ratio,
                    min_sizes=min_sizes, max_sizes=max_sizes,min_max_aspect_ratios_order=min_max_aspect_ratios_order)
    
    npreds = []
    for idx in range(batch_size):

        # NOTE(review): nothing in this loop body depends on `idx` — every
        # iteration rebuilds the same priors and decodes the same tensors, so
        # all batch entries are identical. Confirm the intended batch handling.
        inputs_fake = [paddle.to_tensor(np.ones((1,1,int(val),int(val)),dtype=np.float32)) for val in feat_size]
        image_fake = paddle.to_tensor(np.ones((1,1,int(height),int(width)),dtype=np.float32))
        prior_boxes = ssd_bbox(inputs_fake, image_fake)
        boxes = [paddle.to_tensor(outputs[i]) for i in range(len(outputs)//2)]
        scores = [paddle.to_tensor(outputs[i+len(outputs)//2]) for i in range(len(outputs)//2)]
        boxes = paddle.concat(boxes, axis=1)
        prior_boxes = paddle.concat(prior_boxes)

        # decode center/size offsets against the priors
        # (variances 0.1 for centers, 0.2 for sizes)
        pb_w = prior_boxes[:, 2] - prior_boxes[:, 0] 
        pb_h = prior_boxes[:, 3] - prior_boxes[:, 1]
        pb_x = prior_boxes[:, 0] + pb_w * 0.5
        pb_y = prior_boxes[:, 1] + pb_h * 0.5
        out_x = pb_x + boxes[:, :, 0] * pb_w * 0.1
        out_y = pb_y + boxes[:, :, 1] * pb_h * 0.1
        out_w = paddle.exp(boxes[:, :, 2] * 0.2) * pb_w
        out_h = paddle.exp(boxes[:, :, 3] * 0.2) * pb_h
        # center/size -> corner form
        output_boxes = paddle.stack(
            [
                out_x - out_w / 2., out_y - out_h / 2., out_x + out_w / 2.,
                out_y + out_h / 2.
            ],
            axis=-1)

        output_scores = F.softmax(paddle.concat(
            scores, axis=1)).transpose([0, 2, 1])
        
        preds = pd_multiclass_nms(output_boxes, output_scores, 0.01, 0.3, 750, 1, 5000)
        preds = preds.numpy()
        # multiclass_nms3 rows are [label, score, x1, y1, x2, y2];
        # reorder to [x1, y1, x2, y2, score]
        preds = np.concatenate((preds[:,2:6], preds[:,1:2]),axis=-1)
        npreds.append(preds)        
        
    return npreds

def comute_retinaface_map(preds, params, save_path, iters):
    """Dump predictions in WIDER submission format and score them.

    Reads ./data/wider_val.txt for image names and sizes, writes one
    "<name> <count> <x y w h score>..." file per image under a
    uuid-suffixed images_ directory, then returns evaluation() on it.
    """
    run_tag = str(uuid.uuid1())[:8]
    with open("./data/wider_val.txt", "r") as f:
        val_entries = f.readlines()[:iters]

    for idx, entry in enumerate(val_entries):
        img_name, width, height = entry.strip().split(" ")
        save_name = (save_path + "/" + img_name[:-4] + ".txt").replace("/images", "/images_" + run_tag)
        out_dir = os.path.dirname(save_name)
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

        bboxs = preds[idx]
        with open(save_name, "w") as fd:
            fd.write(os.path.basename(save_name)[:-4] + "\n")
            fd.write(str(len(bboxs)) + "\n")
            for box in bboxs:
                if "model_type" not in params.keys():
                    # normalized coords are scaled back by the image size
                    scale_x, scale_y = float(width), float(height)
                else:
                    # model_type set: scale by the longer side — presumably
                    # the model ran on square (padded) inputs
                    scale_x = scale_y = max(float(width), float(height))
                x = int(box[0] * scale_x)
                y = int(box[1] * scale_y)
                w = int((box[2] - box[0]) * scale_x)
                h = int((box[3] - box[1]) * scale_y)
                confidence = str(box[4])
                fd.write(f"{x} {y} {w} {h} {confidence} \n")

    return evaluation(save_path + "/images_" + run_tag + "/")

def evaluation(pred, gt_path="./data/ground_truth/", iou_thresh=0.5):
    """Score a WIDER FACE prediction directory against the ground truth.

    Args:
        pred: directory of per-event prediction .txt files (parsed by
            get_preds).
        gt_path: directory with wider_face_val.mat plus the easy/medium/hard
            gt .mat files (see get_gt_boxes).
        iou_thresh: IoU needed for a prediction to match a gt box.

    Returns:
        [easy_ap, medium_ap, hard_ap]; the same values are printed.
    """
    pred = get_preds(pred)
    norm_score(pred)
    facebox_list, event_list, file_list, hard_gt_list, medium_gt_list, easy_gt_list = get_gt_boxes(gt_path)
    event_num = len(event_list)
    # number of score thresholds sampled along the PR curve
    thresh_num = 1000
    settings = ['easy', 'medium', 'hard']
    setting_gts = [easy_gt_list, medium_gt_list, hard_gt_list]
    aps = []
    for setting_id in range(3):
        # different setting
        gt_list = setting_gts[setting_id]
        count_face = 0
        # accumulates (proposal count, recall count) per score threshold
        pr_curve = np.zeros((thresh_num, 2)).astype('float')
        # [hard, medium, easy]
        pbar = tqdm.tqdm(range(event_num))
        for i in pbar:
            pbar.set_description('Processing {}'.format(settings[setting_id]))
            event_name = str(event_list[i][0][0])
            img_list = file_list[i][0]
            pred_list = pred[event_name]
            sub_gt_list = gt_list[i][0]
            # img_pr_info_list = np.zeros((len(img_list), thresh_num, 2))
            gt_bbx_list = facebox_list[i][0]

            for j in range(len(img_list)):
                pred_info = pred_list[str(img_list[j][0][0])]

                gt_boxes = gt_bbx_list[j][0].astype('float')
                keep_index = sub_gt_list[j][0]
                count_face += len(keep_index)

                if len(gt_boxes) == 0 or len(pred_info) == 0:
                    continue
                # ignore[k] == 1 marks gt boxes counted for this difficulty;
                # keep_index comes from a .mat file, the -1 converts its
                # 1-based indices to 0-based
                ignore = np.zeros(gt_boxes.shape[0])
                if len(keep_index) != 0:
                    ignore[keep_index-1] = 1
                pred_recall, proposal_list = image_eval(pred_info, gt_boxes, ignore, iou_thresh)

                _img_pr_info = img_pr_info(thresh_num, pred_info, proposal_list, pred_recall)

                pr_curve += _img_pr_info
        # convert raw counts into (precision, recall) pairs
        pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)

        propose = pr_curve[:, 0]
        recall = pr_curve[:, 1]

        ap = voc_ap(recall, propose)
        aps.append(ap)

    print("==================== Results ====================")
    print("Easy   Val AP: {}".format(aps[0]))
    print("Medium Val AP: {}".format(aps[1]))
    print("Hard   Val AP: {}".format(aps[2]))
    print("=================================================")
    return aps


def bbox_overlaps(boxes_c, query_boxes_c):
    """Pairwise IoU between two corner-format box sets.

    Uses the inclusive-pixel (+1) width/height convention and a 1e-6
    denominator epsilon.

    Parameters
    ----------
    boxes_c: (N, 4) ndarray of float
    query_boxes_c: (K, 4) ndarray of float

    Returns
    -------
    (N, K) ndarray of IoU between every box pair.
    """
    b = np.expand_dims(boxes_c, axis=1)
    q = np.expand_dims(query_boxes_c, axis=0)

    iw = np.maximum(
        .0,
        np.minimum(b[..., 2], q[..., 2]) - np.maximum(b[..., 0], q[..., 0]) + 1)
    ih = np.maximum(
        .0,
        np.minimum(b[..., 3], q[..., 3]) - np.maximum(b[..., 1], q[..., 1]) + 1)
    inter = iw * ih

    area_b = (b[..., 2] - b[..., 0] + 1) * (b[..., 3] - b[..., 1] + 1)
    area_q = (q[..., 2] - q[..., 0] + 1) * (q[..., 3] - q[..., 1] + 1)

    return inter / ((area_b + area_q) - inter + 1e-6)

def get_gt_boxes(gt_dir):
    """Load the WIDER val ground-truth annotations.

    gt_dir must contain wider_face_val.mat, wider_easy_val.mat,
    wider_medium_val.mat and wider_hard_val.mat.

    Returns:
        (facebox_list, event_list, file_list,
         hard_gt_list, medium_gt_list, easy_gt_list)
    """
    gt_mat = loadmat(os.path.join(gt_dir, 'wider_face_val.mat'))
    difficulty_mats = {
        'hard': loadmat(os.path.join(gt_dir, 'wider_hard_val.mat')),
        'medium': loadmat(os.path.join(gt_dir, 'wider_medium_val.mat')),
        'easy': loadmat(os.path.join(gt_dir, 'wider_easy_val.mat')),
    }

    return (
        gt_mat['face_bbx_list'],
        gt_mat['event_list'],
        gt_mat['file_list'],
        difficulty_mats['hard']['gt_list'],
        difficulty_mats['medium']['gt_list'],
        difficulty_mats['easy']['gt_list'],
    )


def read_pred_file(filepath):
    """Parse one WIDER-format prediction file.

    Line 0 is the image path, line 1 the detection count, and each
    following line "x y w h score ...".

    Returns:
        (image basename, (N, 5) ndarray of detections).
    """
    with open(filepath, 'r') as f:
        lines = f.readlines()
    img_file = lines[0].rstrip('\n\r')

    detections = []
    for raw in lines[2:]:
        fields = raw.rstrip('\r\n').split(' ')
        if fields[0] == '':
            continue
        detections.append([float(v) for v in fields[:5]])

    return img_file.split('/')[-1], np.array(detections)


def get_preds(pred_dir):
    """Read every per-event prediction file under `pred_dir`.

    Returns:
        dict mapping event name -> {image name (without .jpg) -> (N, 5)
        detection ndarray}.
    """
    events = os.listdir(pred_dir)
    boxes = dict()
    pbar = tqdm.tqdm(events)

    for event in pbar:
        pbar.set_description('Reading Predictions ')
        event_dir = os.path.join(pred_dir, event)
        event_images = os.listdir(event_dir)
        current_event = dict()
        for imgtxt in event_images:
            imgname, _boxes = read_pred_file(os.path.join(event_dir, imgtxt))
            # Fix: rstrip('.jpg') strips any trailing '.', 'j', 'p', 'g'
            # characters (e.g. "img_g.jpg" -> "img_"), corrupting keys;
            # strip the exact suffix instead.
            key = imgname[:-4] if imgname.endswith('.jpg') else imgname
            current_event[key] = _boxes
        boxes[event] = current_event
    return boxes


def norm_score(pred):
    """Min-max rescale every confidence score in *pred*, in place.

    pred {event: {image: [[x1,y1,x2,y2,s], ...]}}; empty arrays are skipped.
    The running extrema start at (0, 1), matching the original sentinels.
    """
    max_score = 0
    min_score = 1

    # first pass: global score range over all non-empty detection arrays
    for event_preds in pred.values():
        for det in event_preds.values():
            if len(det) == 0:
                continue
            min_score = min(min_score, np.min(det[:, -1]))
            max_score = max(max_score, np.max(det[:, -1]))

    diff = max_score - min_score
    # second pass: rescale scores into [0, 1] relative to that range
    for event_preds in pred.values():
        for det in event_preds.values():
            if len(det) == 0:
                continue
            det[:, -1] = (det[:, -1] - min_score) / diff


def image_eval(pred, gt, ignore, iou_thresh):
    """Match one image's predictions against its ground truth.

    Args:
        pred: (N, 5) array of [x, y, w, h, score] detections, assumed sorted
            by descending score — TODO confirm against caller.
        gt: (M, 4) array of [x, y, w, h] ground-truth boxes.
        ignore: (M,) array; 1 marks gt boxes counted for the current
            difficulty setting, 0 marks boxes to skip (built in evaluation()).
        iou_thresh: minimum IoU for a match.

    Returns:
        (pred_recall, proposal_list): pred_recall[h] is the number of distinct
        counted gt boxes recalled by the first h+1 predictions;
        proposal_list[h] is -1 when prediction h best-matched a skipped gt
        box, else 1.
    """

    _pred = pred.copy()
    _gt = gt.copy()
    pred_recall = np.zeros(_pred.shape[0])
    # recall_list[m]: 0 = unmatched, 1 = recalled, -1 = matched but skipped
    recall_list = np.zeros(_gt.shape[0])
    proposal_list = np.ones(_pred.shape[0])

    # convert [x, y, w, h] -> [x1, y1, x2, y2] on the copies
    _pred[:, 2] = _pred[:, 2] + _pred[:, 0]
    _pred[:, 3] = _pred[:, 3] + _pred[:, 1]
    _gt[:, 2] = _gt[:, 2] + _gt[:, 0]
    _gt[:, 3] = _gt[:, 3] + _gt[:, 1]

    overlaps = bbox_overlaps(_pred[:, :4], _gt)

    for h in range(_pred.shape[0]):

        # greedily match each prediction to its highest-IoU gt box
        gt_overlap = overlaps[h]
        max_overlap, max_idx = gt_overlap.max(), gt_overlap.argmax()
        if max_overlap >= iou_thresh:
            if ignore[max_idx] == 0:
                recall_list[max_idx] = -1
                proposal_list[h] = -1
            elif recall_list[max_idx] == 0:
                recall_list[max_idx] = 1

        r_keep_index = np.where(recall_list == 1)[0]
        pred_recall[h] = len(r_keep_index)
    return pred_recall, proposal_list


def img_pr_info(thresh_num, pred_info, proposal_list, pred_recall):
    """Sample one image's (proposal count, recall count) at each threshold.

    Args:
        thresh_num: number of score thresholds to sample in (0, 1).
        pred_info: (N, 5) detections; column 4 is the score.
        proposal_list: (N,) array from image_eval (1 counted, -1 skipped).
        pred_recall: (N,) cumulative recall counts from image_eval.

    Returns:
        (thresh_num, 2) array of [proposals kept, gt boxes recalled].
    """
    pr_info = np.zeros((thresh_num, 2)).astype('float')
    for t in range(thresh_num):
        thresh = 1 - (t + 1) / thresh_num
        above = np.where(pred_info[:, 4] >= thresh)[0]
        if len(above) == 0:
            # row stays (0, 0) — nothing scores above this threshold
            continue
        last = above[-1]
        pr_info[t, 0] = np.count_nonzero(proposal_list[:last + 1] == 1)
        pr_info[t, 1] = pred_recall[last]
    return pr_info


def dataset_pr_info(thresh_num, pr_curve, count_face):
    """Turn accumulated counts into (precision, recall) per threshold.

    pr_curve columns are (proposal count, recall count); the result's
    columns are (recall/proposals, recall/count_face).
    """
    _pr_curve = np.zeros((thresh_num, 2))
    _pr_curve[:, 0] = pr_curve[:, 1] / pr_curve[:, 0]
    _pr_curve[:, 1] = pr_curve[:, 1] / count_face
    return _pr_curve


def voc_ap(rec, prec):
    """Average precision from recall/precision arrays (VOC-style, all points).

    Appends sentinel endpoints, takes the monotone precision envelope, and
    sums precision * delta-recall where recall changes.
    """
    # sentinel endpoints for the curve
    mrec = np.concatenate(([0.], rec, [1.]))
    mpre = np.concatenate(([0.], prec, [0.]))

    # precision envelope: running maximum taken from the right
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]

    # indices where the recall value changes
    change = np.where(mrec[1:] != mrec[:-1])[0]

    # area under the stepwise curve: sum of (delta recall) * precision
    return np.sum((mrec[change + 1] - mrec[change]) * mpre[change + 1])

